import matplotlib

from .hook import AfterIterHook, AfterRunHook

matplotlib.use('agg')
import matplotlib.pyplot as plt


class LossPlotHook(AfterIterHook, AfterRunHook):
    """Record the loss of every iteration and save it as a line plot.

    ``after_iter`` appends one scalar loss per call to ``self.samples``;
    ``after_run`` renders those samples as a curve and writes the image to
    ``self.filename``.

    NOTE(review): the exact points at which the framework invokes
    ``after_iter`` / ``after_run`` are defined by the base hook classes,
    which are not visible here.
    """

    def __init__(self, filename: str, name: str = "LossPlotHook"):
        """Create the hook.

        Args:
            filename: Path handed to ``savefig`` when the plot is written.
            name: Hook name forwarded to the base-class constructor.
        """
        super().__init__(name)
        # One float per ``after_iter`` call, in call order.
        self.samples = []
        self.filename = filename

    def after_iter(self, container) -> None:
        """Log and store the mean loss found on ``container``.

        Args:
            container: Object exposing a ``loss`` attribute that supports
                ``.detach().mean().item()`` (presumably a torch tensor --
                confirm against the caller).
        """
        loss = container.loss.detach().mean().item()
        # ``self.log`` is presumably provided by the hook base classes.
        self.log(f"{loss:.4f}", container)
        self.samples.append(loss)

    def after_run(self, container) -> None:
        """Plot the collected losses against their index and save the image.

        The curve is drawn on a dedicated figure that is closed afterwards,
        so repeated runs (or several hooks in one process) neither draw on
        top of each other via pyplot's implicit current figure nor leave
        figures open.

        Args:
            container: Unused here; accepted to match the hook interface.
        """
        fig, ax = plt.subplots()
        try:
            ax.set_title("Loss Curve")
            ax.set_xlabel("Iteration")
            ax.set_ylabel("Loss")
            ax.plot(range(len(self.samples)), self.samples)
            fig.savefig(self.filename)
        finally:
            # Always release the figure, even if saving fails.
            plt.close(fig)
